{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run upon export from spreadsheet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from astroquery.mast import Catalogs\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Show up to 50 columns when a DataFrame is displayed in this notebook.\n",
    "pd.set_option('display.max_columns', 50)\n",
    "\n",
    "def generate_data(include_toi):\n",
    "    \"\"\"Build the labeled TCE tables and split them into train/val/test.\n",
    "\n",
    "    Reads the TCE table, replaces its ephemerides with the values from the\n",
    "    corrections file wherever a correction exists for the TIC, and inner-joins\n",
    "    the result with the vetting labels. When ``include_toi`` is true, rows of\n",
    "    the TOI/EB file whose TIC is not already present are added as well. Rows\n",
    "    whose disposition counts are all zero are dropped.\n",
    "\n",
    "    Args:\n",
    "        include_toi: whether to add the rows of ``tce_toi_vetting_p+eb.csv``.\n",
    "\n",
    "    Returns:\n",
    "        ``(t_train, t_val, t_test, all_table)``: the three splits and the full\n",
    "        labeled table, indexed by ``tic_id``, without the ``Split`` column.\n",
    "    \"\"\"\n",
    "    tces_file = '/mnt/tess/labels/tce_bls_to_y3.csv'\n",
    "    labels_file = '/mnt/tess/labels/labels_vetting_v1.csv'\n",
    "    corrections_file = '/mnt/tess/labels/e_label_ephemerides.csv'\n",
    "    toi_file = '/mnt/tess/labels/tce_toi_vetting_p+eb.csv'\n",
    "\n",
    "    disps = ['e', 'p', 'n', 'b', 't', 'u']\n",
    "    users = ['ch', 'et', 'md', 'as', 'mk']\n",
    "\n",
    "    tce_table = pd.read_csv(tces_file, header=0, low_memory=False).set_index('tic_id')\n",
    "    tce_table = tce_table.drop(columns=['Unnamed: 0'])\n",
    "    # Same /24 scaling as the corrections below (presumably hours -> days; TODO confirm).\n",
    "    tce_table['Duration'] /= 24.0\n",
    "\n",
    "    corrections_table = pd.read_csv(corrections_file, header=0, low_memory=False)\n",
    "    corrections_table['tic_id'] = corrections_table['tic']\n",
    "    corrections_table = corrections_table.set_index('tic_id')\n",
    "    corrections_table['dur'] /= 24.0\n",
    "    # Presumably fractional depth -> ppm, to match Transit_Depth; TODO confirm.\n",
    "    corrections_table['dep'] *= 1e6\n",
    "\n",
    "    joined_table = tce_table.join(corrections_table, on='tic_id', how='left')\n",
    "    # A corrected value, when present, replaces the one from the TCE table.\n",
    "    for column, override in (\n",
    "        ('Epoc', 'epo'),\n",
    "        ('Duration', 'dur'),\n",
    "        ('Period', 'per'),\n",
    "        ('Transit_Depth', 'dep'),\n",
    "    ):\n",
    "        joined_table[column] = np.where(\n",
    "            joined_table[override].isna(), joined_table[column], joined_table[override]\n",
    "        )\n",
    "\n",
    "    joined_table = joined_table.reset_index()[[\n",
    "        'tic_id', 'RA', 'Dec', 'Tmag', 'Epoc', 'Period', 'Duration',\n",
    "        'Transit_Depth', 'Sectors', 'star_rad', 'star_mass', 'teff',\n",
    "        'logg', 'SN', 'Qingress'\n",
    "    ]]\n",
    "\n",
    "    labels_table = pd.read_csv(labels_file, header=0, low_memory=False)\n",
    "    labels_table = labels_table.drop(columns=['261262721'])\n",
    "    labels_table['tic_id'] = labels_table['TIC ID']\n",
    "    for d in disps:\n",
    "        labels_table[f'disp_{d}'] = 0\n",
    "\n",
    "    def set_labels(row):\n",
    "        # A 'Final' label wins: the dispositions named by its first two\n",
    "        # characters are set to 1. Otherwise the first two characters of every\n",
    "        # user's non-empty vote are added up.\n",
    "        present = ~row.isna()\n",
    "        if present['Final']:\n",
    "            final = row['Final']\n",
    "            row[f'disp_{final[0]}'] = 1\n",
    "            row[f'disp_{final[1]}'] = 1\n",
    "        else:\n",
    "            for user in users:\n",
    "                if present[user] and row[user]:\n",
    "                    vote = row[user]\n",
    "                    row[f'disp_{vote[0]}'] += 1\n",
    "                    row[f'disp_{vote[1]}'] += 1\n",
    "        return row\n",
    "\n",
    "    labels_table = labels_table.apply(set_labels, axis=1)\n",
    "    labels_table = labels_table[['tic_id', 'Split'] + [f'disp_{d}' for d in disps]]\n",
    "\n",
    "    joined_table = joined_table.set_index('tic_id')\n",
    "    labels_table = labels_table.set_index('tic_id')\n",
    "    joined_table = joined_table.join(labels_table, on='tic_id', how='inner')\n",
    "\n",
    "    # The TOI file is read and the global NumPy seed is set on every call,\n",
    "    # whether or not the rows end up being used.\n",
    "    toi = pd.read_csv(toi_file, header=0, low_memory=False).set_index('tic_id')\n",
    "\n",
    "    # Fixed seed so TOI rows land in the same split on every run:\n",
    "    # draws 0-79 -> train, 80-89 -> val, 90-99 -> test.\n",
    "    np.random.seed(1117)\n",
    "    toi['rand'] = np.random.randint(0, 100, [len(toi)])\n",
    "    toi['Split'] = np.where(\n",
    "        toi['rand'] < 80, 'train', np.where(toi['rand'] >= 90, 'test', 'val')\n",
    "    )\n",
    "    toi = toi.drop(columns=['rand'])\n",
    "\n",
    "    if include_toi:\n",
    "        # Trust the curated labels. This might help us to detect any inconsistencies in TOI.\n",
    "        toi = toi[~toi.index.isin(joined_table.index.values)]\n",
    "        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.\n",
    "        # NOTE(review): any extra columns present in the TOI file are carried into\n",
    "        # the result unchanged - confirm downstream readers tolerate them.\n",
    "        joined_table = pd.concat([joined_table, toi], sort=False)\n",
    "\n",
    "    print(f'Total entries: {len(joined_table)}')\n",
    "    has_label = sum(joined_table[f'disp_{d}'] for d in disps) > 0\n",
    "    joined_table = joined_table[has_label]\n",
    "    print(f'Total labeled entries: {len(joined_table)}')\n",
    "\n",
    "    all_table = joined_table\n",
    "\n",
    "    t_train = joined_table[joined_table['Split'] == 'train']\n",
    "    t_val = joined_table[joined_table['Split'] == 'val']\n",
    "    t_test = joined_table[joined_table['Split'] == 'test']\n",
    "    print(f'Split sizes. Train: {len(t_train)}; Valid: {len(t_val)}; Test: {len(t_test)}')\n",
    "    print(f'Duplicate TICs: {len(all_table.index.values) - len(set(all_table.index.values))}')\n",
    "\n",
    "    t_train = t_train.drop(columns=['Split'])\n",
    "    t_val = t_val.drop(columns=['Split'])\n",
    "    t_test = t_test.drop(columns=['Split'])\n",
    "    all_table = all_table.drop(columns=['Split'])\n",
    "\n",
    "    return t_train, t_val, t_test, all_table\n",
    "\n",
    "\n",
    "# Splits that include the TOI/EB rows, plus the combined table (all splits together).\n",
    "t_train, t_val, t_test, all_table = generate_data(True)\n",
    "t_train.to_csv('/mnt/tess/astronet/tces-vetting-v7-toi-train.csv')\n",
    "t_val.to_csv('/mnt/tess/astronet/tces-vetting-v7-toi-val.csv')\n",
    "t_test.to_csv('/mnt/tess/astronet/tces-vetting-v7-toi-test.csv')\n",
    "\n",
    "all_table.to_csv('/mnt/tess/astronet/tces-vetting-all.csv')\n",
    "\n",
    "# Splits built from the curated labels only. `all_table` keeps the TOI-inclusive\n",
    "# table from the first call; t_train/t_val/t_test are rebound to these splits.\n",
    "t_train, t_val, t_test, _ = generate_data(False)\n",
    "t_train.to_csv('/mnt/tess/astronet/tces-vetting-v7-train.csv')\n",
    "t_val.to_csv('/mnt/tess/astronet/tces-vetting-v7-val.csv')\n",
    "t_test.to_csv('/mnt/tess/astronet/tces-vetting-v7-test.csv')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eyeball 30 random rows of the TOI-inclusive table written to tces-vetting-all.csv.\n",
    "all_table.sample(30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
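    "### Run once\n",
    "\n",
    "Builds `/mnt/tess/labels/tce_toi_vetting_p+eb.csv`, which `generate_data` in the export cell above reads, so this cell must have been run at least once before that one."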
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def clean_tois():\n",
    "    \"\"\"Combine planet TOIs and additional eclipsing binaries into one labeled CSV.\n",
    "\n",
    "    Reads ``toi_p.csv`` (every kept row is labeled ``disp_p``), reads\n",
    "    ``ebs_ephemerides.csv`` restricted to the TICs listed in\n",
    "    ``additionalebs.csv`` (labeled ``disp_e`` plus one of ``disp_b`` /\n",
    "    ``disp_t`` / ``disp_u``), stacks both and writes the result to\n",
    "    ``/mnt/tess/labels/tce_toi_vetting_p+eb.csv``. Also prints the row counts.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if an EB row's label (column 3 of ``additionalebs.csv``)\n",
    "            is not one of 'b', 't' or 'u'.\n",
    "    \"\"\"\n",
    "    feature_columns = [\n",
    "        'tic_id', 'RA', 'Dec', 'Tmag', 'Epoc', 'Period', 'Duration',\n",
    "        'Transit_Depth', 'star_rad', 'star_mass', 'teff', 'logg',\n",
    "    ]\n",
    "    disp_columns = ['disp_e', 'disp_p', 'disp_n', 'disp_b', 'disp_t', 'disp_u']\n",
    "    output_columns = feature_columns + disp_columns\n",
    "\n",
    "    toi = pd.read_csv('/mnt/tess/labels/toi_p.csv', header=0, low_memory=False)\n",
    "    # Keeps rows with Period < 99999 (rows with a missing Period are dropped too).\n",
    "    # NOTE(review): 99999 looks like a placeholder for an unknown period - confirm.\n",
    "    # .copy() so the column assignments below act on an independent frame, not a slice.\n",
    "    toi = toi[toi['Period'] < 99999].copy()\n",
    "    for column in disp_columns:\n",
    "        toi[column] = int(column == 'disp_p')\n",
    "    toi = toi[output_columns].copy()\n",
    "    # Unseeded draw in [0, 100), one per row.\n",
    "    toi['rand'] = np.random.randint(0, 100, len(toi))\n",
    "\n",
    "    ebs = pd.read_csv('/mnt/tess/labels/ebs_ephemerides.csv', header=0, low_memory=False)\n",
    "    # Presumably hours -> days and fractional depth -> ppm; TODO confirm units.\n",
    "    ebs['dur'] /= 24.0\n",
    "    ebs['dep'] *= 1e6\n",
    "\n",
    "    ebs['tic_id'] = ebs['tic']\n",
    "    ebs['Epoc'] = ebs['epo']\n",
    "    ebs['Period'] = ebs['per']\n",
    "    ebs['Duration'] = ebs['dur']\n",
    "    ebs['Transit_Depth'] = ebs['dep']\n",
    "    # These columns are not filled from the EB file; they are written as empty.\n",
    "    for column in ('RA', 'Dec', 'Tmag', 'star_rad', 'star_mass', 'teff', 'logg'):\n",
    "        ebs[column] = None\n",
    "\n",
    "    eb_labels = pd.read_csv('/mnt/tess/labels/additionalebs.csv', header=None, low_memory=False)\n",
    "    # The two index names differ ('tic_id' vs 0), so the original code reads the\n",
    "    # joined TIC back from a column named 'index' after reset_index().\n",
    "    ebs = ebs.set_index('tic_id').join(eb_labels.set_index(0), how='inner').reset_index()\n",
    "    ebs['tic_id'] = ebs['index']\n",
    "\n",
    "    def set_label(row):\n",
    "        # Column 3 of additionalebs.csv carries the EB sub-type letter;\n",
    "        # .loc makes the lookup explicitly label-based.\n",
    "        kind = row.loc[3]\n",
    "        row['disp_p'] = 0\n",
    "        row['disp_e'] = 1\n",
    "        if kind not in ('b', 't', 'u'):\n",
    "            raise ValueError(row)\n",
    "        for letter in ('b', 't', 'u'):\n",
    "            row[f'disp_{letter}'] = int(letter == kind)\n",
    "        row['disp_n'] = 0\n",
    "        return row\n",
    "\n",
    "    ebs = ebs.apply(set_label, axis=1)\n",
    "    ebs = ebs[output_columns].drop_duplicates()\n",
    "    ebs['rand'] = np.random.randint(0, 100, len(ebs))\n",
    "\n",
    "    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.\n",
    "    toi_pe = pd.concat([toi, ebs], sort=False)\n",
    "    toi_pe.to_csv('/mnt/tess/labels/tce_toi_vetting_p+eb.csv')\n",
    "\n",
    "    print('TOIs:', len(toi), 'EBs:', len(ebs), 'All:', len(toi_pe))\n",
    "\n",
    "# Writes /mnt/tess/labels/tce_toi_vetting_p+eb.csv.\n",
    "clean_tois()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
